"""
STAR dataset, fixed epsilon method
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime

warnings.filterwarnings('ignore')


class OutputLogger:
    """
    Tee-style stream that mirrors everything written to it into a log file.

    The instance captures the current ``sys.stdout`` at construction time and
    forwards every message both to that stream and to the log file. It can be
    assigned to ``sys.stdout`` or used as a context manager, in which case the
    log file is closed on exit.
    """

    def __init__(self, filename):
        """
        Initialize logger with output file.

        Parameters:
        -----------
        filename : str
            Output file path
        """
        self.terminal = sys.stdout
        # The analysis prints non-ASCII symbols (ε, γ, √, ...). Pinning the
        # encoding keeps the log from failing with UnicodeEncodeError on
        # platforms whose default text encoding is not UTF-8.
        self.log = open(filename, 'w', encoding='utf-8')

    def __enter__(self):
        """Return the logger itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file; exceptions are never suppressed."""
        self.close()
        return False

    def write(self, message):
        """Write message to both terminal and file."""
        self.terminal.write(message)
        # Once the file is closed keep echoing to the terminal instead of
        # raising, so a late print() cannot crash the caller.
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()

    def flush(self):
        """Flush both output streams."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close file output stream (safe to call more than once)."""
        self.log.close()


class CATEAllocationAlgorithm:
    """
    Implementation of the theoretical CATE allocation algorithm.

    This class implements the allocation algorithm with fixed gamma parameter and
    configurable heavy interval threshold.
    """

    def __init__(self, epsilon=0.1, gamma=0.5, delta=0.05, heavy_multiplier=1.6, random_seed=42):
        """
        Initialize the allocation algorithm.

        Parameters:
        -----------
        epsilon : float, default=0.1
            Approximation parameter for (1-ε)-optimal allocation
        gamma : float, default=0.5
            Fixed parameter controlling relationship between ε and ρ
        delta : float, default=0.05
            Confidence parameter for Hoeffding's inequality
        heavy_multiplier : float, default=1.6
            Threshold multiplier for heavy interval detection
        random_seed : int, default=42
            Random seed for reproducibility
        """
        self.epsilon = epsilon
        self.gamma = gamma
        self.rho = gamma * np.sqrt(epsilon)
        self.delta = delta
        self.heavy_multiplier = heavy_multiplier
        self.random_seed = random_seed
        np.random.seed(random_seed)

        print(f"CATE Allocation Algorithm")
        print(f"ε = {epsilon}")
        print(f"√ε = {np.sqrt(epsilon):.6f}")
        print(f"γ = {gamma}")
        print(f"ρ = γ√ε = {self.rho:.6f}")
        print(f"Heavy multiplier = {heavy_multiplier}x")
        print(f"δ = {delta}")
        print("="*60)

    def process_star_data(self, df, outcome_col=None):
        """
        Process STAR dataset.

        Parameters:
        -----------
        df : pandas.DataFrame
            Raw STAR dataset
        outcome_col : str, optional
            Custom outcome column name

        Returns:
        --------
        pandas.DataFrame
            Processed dataset ready for analysis
        """
        print(f"Processing STAR data with {len(df)} observations")

        df_processed = df.copy()

        # Validate required columns
        required_cols = ['gkschid', 'gkclasstype']
        missing = [col for col in required_cols if col not in df_processed.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Filter treatment groups
        print(f"Original class type distribution: {df_processed['gkclasstype'].value_counts().to_dict()}")
        df_processed = df_processed[df_processed['gkclasstype'] != 'REGULAR + AIDE CLASS']
        print(f"After excluding aide classes: {len(df_processed)} observations")

        # Create binary treatment indicator
        treatment_map = {'SMALL CLASS': 1, 'REGULAR CLASS': 0}
        df_processed['treatment'] = df_processed['gkclasstype'].map(treatment_map)

        # Construct composite outcome measure
        test_components = ['gktreadss', 'gktmathss', 'gktlistss', 'gkwordskillss']
        available_components = [col for col in test_components if col in df_processed.columns]

        if not available_components:
            raise ValueError("No test score components found")

        # Remove observations with missing outcomes
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=available_components)
        print(f"Dropped {initial_size - len(df_processed)} rows due to missing test scores")

        df_processed['total_score'] = df_processed[available_components].sum(axis=1)
        df_processed['outcome'] = df_processed['total_score']

        # Final data cleaning
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['treatment', 'gkschid'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing treatment/school data")

        print(f"Final dataset: {final_size} students")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")

        return df_processed

    def create_school_groups(self, df, min_size=6):
        """Create groups based on school identifiers."""
        print(f"Creating school-based groups (min_size={min_size})")

        groups = []
        for school_id in df['gkschid'].unique():
            indices = df[df['gkschid'] == school_id].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'school_{school_id}',
                    'indices': indices,
                    'type': 'school'
                })

        print(f"Raw groups created: {len(groups)}")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        print(f"Balanced groups after filtering: {len(balanced_groups)}")
        return balanced_groups

    def create_ml_prediction_groups(self, df, n_groups=30, min_size=6):
        """
        Create groups using machine learning-based treatment effect prediction.

        Uses separate Random Forest models for treated and control outcomes,
        then clusters (KMeans) on the standardized covariates plus the
        predicted treatment effect.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed data with ``treatment`` and ``outcome`` columns
        n_groups : int, default=30
            Number of KMeans clusters requested
        min_size : int, default=6
            Minimum cluster size for a cluster to become a group

        Returns:
        --------
        list
            Balanced group dicts with CATE attached; empty if either
            treatment arm has no observations
        """
        print(f"Creating ML prediction-based groups (target: {n_groups})")

        # Select baseline features. Besides the derived columns, drop
        # 'gkclasstype' (process_star_data derives the treatment from it, so
        # clustering on it would split groups by arm) and 'gkwordskillss'
        # (an outcome component that the 'gkt' prefix filter does not catch).
        excluded_cols = {'treatment', 'outcome', 'total_score', 'gkclasstype', 'gkwordskillss'}
        feature_cols = [col for col in df.columns
                        if col not in excluded_cols
                        and not col.startswith('gkt')]

        X = df[feature_cols].copy()

        # Preprocess features: label-encode text/categorical, cast bools to int
        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))
            elif X[col].dtype == 'bool':
                X[col] = X[col].astype(int)

        # Handle missing values: median for int64/float64, mode otherwise
        for col in X.columns:
            if X[col].isna().any():
                if X[col].dtype in ['int64', 'float64']:
                    X[col] = X[col].fillna(X[col].median())
                else:
                    X[col] = X[col].fillna(X[col].mode()[0] if len(X[col].mode()) > 0 else 0)

        # Train separate outcome models
        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() == 0 or control_mask.sum() == 0:
            print("Insufficient treated or control observations")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        # Predict treatment effects and perform clustering
        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        labels = KMeans(n_clusters=n_groups, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_groups):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'ml_prediction_{i}',
                    'indices': indices,
                    'type': 'ml_prediction'
                })

        print(f"Created {len(groups)} ML prediction-based groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """
        Create groups based on propensity score stratification.

        A logistic regression, evaluated with 5-fold ``cross_val_predict``,
        estimates each row's probability of treatment; rows are then split
        into ``n_groups`` quantile strata of that score.

        Parameters:
        -----------
        df : pandas.DataFrame
            Processed data with ``treatment`` and ``outcome`` columns
        n_groups : int, default=50
            Number of quantile strata
        min_size : int, default=6
            Minimum stratum size for a stratum to become a group

        Returns:
        --------
        list
            Balanced group dicts with CATE attached
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        # 'gkclasstype' is excluded because process_star_data derives the
        # treatment from it; using it as a covariate lets the model predict
        # treatment from the treatment itself.
        # NOTE(review): test-score columns are still used as covariates here;
        # confirm that is intended for a propensity model.
        excluded_cols = {'treatment', 'outcome', 'total_score', 'gkclasstype'}
        feature_cols = [col for col in df.columns if col not in excluded_cols]

        X = df[feature_cols].copy()

        # Preprocess features: label-encode text/categorical, cast bools to int
        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))
            elif X[col].dtype == 'bool':
                X[col] = X[col].astype(int)

        # Handle missing values: median for int64/float64, mode otherwise
        for col in X.columns:
            if X[col].isna().any():
                if X[col].dtype in ['int64', 'float64']:
                    X[col] = X[col].fillna(X[col].median())
                else:
                    X[col] = X[col].fillna(X[col].mode()[0] if len(X[col].mode()) > 0 else 0)

        # Estimate propensity scores with cross-validation
        prop_scores = cross_val_predict(
            LogisticRegression(random_state=self.random_seed),
            X, df['treatment'], method='predict_proba', cv=5
        )[:, 1]

        # Create quantile-based strata. Only the interior edges are passed to
        # digitize so the result is 0..n_groups-1; including the outer edges
        # would put the maximum score in bin n_groups, which is never
        # collected below.
        quantiles = np.linspace(0, 1, n_groups + 1)
        edges = np.quantile(prop_scores, quantiles)
        bins = np.digitize(prop_scores, edges[1:-1])

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity score groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_performance_groups(self, df, n_groups=50, min_size=6):
        """Create groups based on baseline academic performance."""
        print(f"Creating performance groups (target: {n_groups})")

        # Identify baseline test score columns
        score_cols = [col for col in df.columns if col.startswith('gkt') and 'ss' in col]
        if not score_cols:
            print("No baseline scores found")
            return []

        baseline_score = df[score_cols].fillna(df[score_cols].mean()).mean(axis=1)

        # Create percentile-based groups
        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(baseline_score, percentiles)
        bins = np.digitize(baseline_score, cuts) - 1

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'performance_{i}',
                    'indices': indices,
                    'type': 'performance'
                })

        print(f"Created {len(groups)} performance groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def create_demographics_groups(self, df, min_size=6, feature_cols=None):
        """Create groups based on demographic characteristic combinations."""
        print(f"Creating demographics groups")

        if feature_cols is None:
            potential_features = ['gkfreelunch', 'race', 'gender', 'birthyear']
        else:
            potential_features = feature_cols

        # Identify available demographic features
        available_features = []
        for col in potential_features:
            if col in df.columns and df[col].notna().sum() > 0:
                available_features.append(col)

        if len(available_features) == 0:
            print("No demographic variables found, using school grouping")
            return self.create_school_groups(df, min_size)

        print(f"Using features: {available_features}")

        # Remove observations with missing demographic data
        df_clean = df[available_features].dropna()
        print(f"After removing missing values: {len(df_clean)}/{len(df)} students")

        if len(df_clean) == 0:
            return self.create_school_groups(df, min_size)

        # Create groups based on unique characteristic combinations
        unique_combinations = df_clean.drop_duplicates()
        print(f"Found {len(unique_combinations)} unique combinations")

        groups = []
        for combo_idx, (idx, combo) in enumerate(unique_combinations.iterrows()):
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            combo_id = "_".join(combo_description)

            if len(indices) >= min_size:
                groups.append({
                    'id': combo_id,
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        return balanced_groups

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Apply treatment balance requirements and compute CATEs."""
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            # Apply balance and minimum size constraints
            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            # Compute CATE as difference in means
            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Normalize CATE values to [0,1] interval."""
        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def plot_cate_distribution(self, groups, title_suffix=""):
        """
        Show side-by-side histograms of raw and normalized CATEs.

        Parameters:
        -----------
        groups : list
            Group dicts with ``cate`` and ``normalized_cate`` entries
        title_suffix : str, default=""
            Text appended to both panel titles
        """
        raw_values = [g['cate'] for g in groups]
        scaled_values = [g['normalized_cate'] for g in groups]

        # (data, bar colour, x-axis label, title) for the left and right panel
        panels = (
            (raw_values, 'skyblue', 'Original CATE',
             f'Original CATE Distribution{title_suffix}'),
            (scaled_values, 'lightcoral', 'Normalized CATE (τ)',
             f'Normalized CATE Distribution{title_suffix}'),
        )

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        for ax, (values, colour, xlabel, title) in zip(axes, panels):
            ax.hist(values, bins=15, alpha=0.7, color=colour, edgecolor='black')
            ax.set_xlabel(xlabel)
            ax.set_ylabel('Frequency')
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def estimate_tau(self, true_tau, accuracy):
        """
        Estimate tau using Hoeffding's inequality with Bernoulli samples.

        Parameters:
        -----------
        true_tau : float
            True normalized CATE value
        accuracy : float
            Desired estimation accuracy

        Returns:
        --------
        tuple
            (estimate, sample_size) pair
        """
        sample_size = int(np.ceil(np.log(2/self.delta) / (2 * accuracy**2)))
        samples = np.random.binomial(1, true_tau, sample_size)
        return np.mean(samples), sample_size

    def run_single_trial(self, groups, epsilon_val, trial_seed):
        """
        Execute one trial of the allocation algorithm.

        Parameters:
        -----------
        groups : list
            List of groups with normalized CATEs
        epsilon_val : float
            Current epsilon value for analysis
        trial_seed : int
            Trial-specific random seed

        Returns:
        --------
        tuple
            (trial_results, tau_estimates) containing performance metrics
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)

        # Estimate all tau values using rho-level accuracy
        tau_estimates_rho = []
        for tau in tau_true:
            estimate, _ = self.estimate_tau(tau, rho)
            tau_estimates_rho.append(estimate)
        tau_estimates_rho = np.array(tau_estimates_rho)

        # Also estimate using epsilon-level accuracy for comparison
        tau_estimates_eps = []
        for tau in tau_true:
            estimate, _ = self.estimate_tau(tau, epsilon_val)
            tau_estimates_eps.append(estimate)
        tau_estimates_eps = np.array(tau_estimates_eps)

        results = []

        # Evaluate all possible budget levels
        for K in range(1, n_groups):
            # Compute optimal allocation
            optimal_indices = np.argsort(tau_true)[-K:]
            optimal_value = np.sum(tau_true[optimal_indices])

            # Compute algorithm allocations
            rho_indices = np.argsort(tau_estimates_rho)[-K:]
            rho_value = np.sum(tau_true[rho_indices])

            eps_indices = np.argsort(tau_estimates_eps)[-K:]
            eps_value = np.sum(tau_true[eps_indices])

            # Calculate performance ratios
            rho_ratio = rho_value / optimal_value if optimal_value > 0 else 0
            eps_ratio = eps_value / optimal_value if optimal_value > 0 else 0
            rho_success = rho_ratio >= (1 - epsilon_val)
            eps_success = eps_ratio >= (1 - epsilon_val)

            # Detect heavy intervals
            tau_k_est = tau_estimates_rho[rho_indices[0]]
            a2_lower = tau_k_est
            a2_upper = tau_k_est + 2 * rho
            units_in_a2 = np.sum((tau_estimates_rho >= a2_lower) & (tau_estimates_rho <= a2_upper))
            expected_a2 = 2 * rho * n_groups
            is_heavy = units_in_a2 > self.heavy_multiplier * expected_a2

            results.append({
                'K': K,
                'optimal_value': optimal_value,
                'rho_value': rho_value,
                'eps_value': eps_value,
                'rho_ratio': rho_ratio,
                'eps_ratio': eps_ratio,
                'rho_success': rho_success,
                'eps_success': eps_success,
                'is_heavy': is_heavy,
                'tau_k_est': tau_k_est,
                'units_in_a2': units_in_a2
            })

        return results, tau_estimates_rho

    def find_recovery_units(self, K, tau_true, tau_estimates, epsilon_val):
        """Determine additional units needed to achieve target performance."""
        n_groups = len(tau_true)

        # Current allocation using rho estimates
        rho_indices = np.argsort(tau_estimates)[-K:]
        optimal_value = np.sum(tau_true[np.argsort(tau_true)[-K:]])

        # Remaining candidates sorted by estimated quality
        remaining_indices = np.argsort(tau_estimates)[:-K][::-1]

        # Test incremental expansion
        for extra in range(1, 11):
            if extra > len(remaining_indices):
                break

            expanded_indices = np.concatenate([rho_indices, remaining_indices[:extra]])
            expanded_value = np.sum(tau_true[expanded_indices])

            if expanded_value / optimal_value >= (1 - epsilon_val):
                return extra

        return None

    def find_closest_working_budget(self, failed_K, trial_results):
        """Find proximity to successful budget levels."""
        working_budgets = [r['K'] for r in trial_results if r['rho_success']]

        if not working_budgets:
            return None, None

        # Distance to any working budget
        distances_any = [abs(K - failed_K) for K in working_budgets]
        min_distance_any = min(distances_any)

        # Distance to smaller working budget
        smaller_working = [K for K in working_budgets if K < failed_K]
        if smaller_working:
            min_distance_smaller = failed_K - max(smaller_working)
        else:
            min_distance_smaller = None

        return min_distance_any, min_distance_smaller

    def analyze_method(self, groups, epsilon_val, n_trials=30):
        """
        Analyze algorithm performance for a single grouping method.

        Parameters:
        -----------
        groups : list
            List of groups with normalized CATEs
        epsilon_val : float
            Current epsilon value
        n_trials : int, default=30
            Number of trials to run

        Returns:
        --------
        list
            Trial data with performance metrics
        """
        print(f"\nAnalyzing {len(groups)} groups with ε={epsilon_val}, γ={self.gamma}")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        trial_data = []

        for trial in range(n_trials):
            print(f"Trial {trial + 1}/{n_trials}...")

            # Execute single trial
            trial_results, tau_estimates = self.run_single_trial(groups, epsilon_val, trial)

            # Analyze failure patterns
            failed_results = [r for r in trial_results if not r['rho_success']]
            failed_budgets = [r['K'] for r in failed_results]

            # Heavy interval analysis with both estimated and true values
            failed_heavy_estimated = []
            failed_heavy_true = []
            rho = self.gamma * np.sqrt(epsilon_val)

            for failed_result in failed_results:
                K = failed_result['K']
                failed_heavy_estimated.append(failed_result['is_heavy'])

                # Check heavy intervals using true tau values
                tau_k_true = tau_true[np.argsort(tau_true)[-K:]][0]
                a2_lower_true = tau_k_true
                a2_upper_true = tau_k_true + 2 * rho
                units_in_a2_true = np.sum((tau_true >= a2_lower_true) & (tau_true <= a2_upper_true))
                expected_a2_true = 2 * rho * n_groups
                is_heavy_true = units_in_a2_true > self.heavy_multiplier * expected_a2_true
                failed_heavy_true.append(is_heavy_true)

            # Report trial summary
            print(f"  Failed budgets: {failed_budgets}")

            if len(failed_budgets) > 0:
                estimated_clean = [bool(x) for x in failed_heavy_estimated]
                true_clean = [bool(x) for x in failed_heavy_true]
                print(f"  Heavy intervals - Estimated: {estimated_clean}")
                print(f"  Heavy intervals - True τ_K:   {true_clean}")

            # Compute aggregate statistics
            total_heavy = sum(r['is_heavy'] for r in trial_results)
            failed_heavy = sum(r['is_heavy'] for r in failed_results)

            # Recovery analysis
            recovery_units = []
            distances_to_working_any = []
            distances_to_working_smaller = []

            for failed_result in failed_results:
                K = failed_result['K']

                # Compute recovery requirements
                recovery = self.find_recovery_units(K, tau_true, tau_estimates, epsilon_val)
                if recovery is not None:
                    recovery_units.append(recovery)

                # Compute distances to working budgets
                distance_any, distance_smaller = self.find_closest_working_budget(K, trial_results)
                if distance_any is not None:
                    distances_to_working_any.append(distance_any)
                if distance_smaller is not None:
                    distances_to_working_smaller.append(distance_smaller)

            trial_info = {
                'trial': trial,
                'failed_budgets': failed_budgets,
                'num_failures': len(failed_results),
                'total_heavy': total_heavy,
                'failed_heavy': failed_heavy,
                'failed_heavy_estimated': failed_heavy_estimated,
                'failed_heavy_true': failed_heavy_true,
                'recovery_units': recovery_units,
                'distances_to_working_any': distances_to_working_any,
                'distances_to_working_smaller': distances_to_working_smaller
            }

            trial_data.append(trial_info)

            # Print trial metrics
            print(f"  Failures: {len(failed_results)}, Total heavy: {total_heavy}, Failed heavy: {failed_heavy}")
            if recovery_units:
                print(f"  Recovery units: μ={np.mean(recovery_units):.1f}, med={np.median(recovery_units):.0f}, max={np.max(recovery_units)}")
            if distances_to_working_any:
                print(f"  Distance any: μ={np.mean(distances_to_working_any):.1f}, med={np.median(distances_to_working_any):.0f}, max={np.max(distances_to_working_any)}")
            if distances_to_working_smaller:
                print(f"  Distance smaller: μ={np.mean(distances_to_working_smaller):.1f}, med={np.median(distances_to_working_smaller):.0f}, max={np.max(distances_to_working_smaller)}")
            else:
                print(f"  Distance smaller: No smaller working budgets found")

        return trial_data

    def print_method_summary(self, method_name, trial_data, n_groups):
        """Generate comprehensive summary statistics for a method."""
        print(f"\n{'='*100}")
        print(f"SUMMARY - {method_name} - ALL BUDGETS")
        print("="*100)
        print(f"{'Fail μ':<7} {'Fail σ':<7} {'FailR% μ':<9} {'FailR% σ':<9} {'TotHvy':<8} {'FailHvy':<9} {'Rec μ':<7} {'Rec med':<8} {'Rec max':<8} {'DAny μ':<8} {'DAny σ':<10} {'DAny max':<10} {'DSmall μ':<10} {'DSmall σ':<12} {'DSmall max':<12}")
        print("-"*120)

        # Aggregate statistics across trials
        all_failures = [t['num_failures'] for t in trial_data]
        all_total_heavy = [t['total_heavy'] for t in trial_data]
        all_failed_heavy = [t['failed_heavy'] for t in trial_data]
        all_recovery = []
        all_distances_any = []
        all_distances_smaller = []

        for t in trial_data:
            all_recovery.extend(t['recovery_units'])
            all_distances_any.extend(t['distances_to_working_any'])
            all_distances_smaller.extend(t['distances_to_working_smaller'])

        # Compute summary statistics
        avg_failures = np.mean(all_failures)
        std_failures = np.std(all_failures)
        avg_failure_rate = avg_failures / (n_groups - 1) * 100
        std_failure_rate = std_failures / (n_groups - 1) * 100
        avg_total_heavy = np.mean(all_total_heavy)
        avg_failed_heavy = np.mean(all_failed_heavy)

        # Recovery statistics
        if all_recovery:
            recovery_mean = np.mean(all_recovery)
            recovery_med = np.median(all_recovery)
            recovery_max = np.max(all_recovery)
        else:
            recovery_mean = recovery_med = recovery_max = np.nan

        # Distance statistics - any direction
        if all_distances_any:
            distance_any_mean = np.mean(all_distances_any)
            distance_any_std = np.std(all_distances_any)
            distance_any_max = np.max(all_distances_any)
        else:
            distance_any_mean = distance_any_std = distance_any_max = np.nan

        # Distance statistics - smaller budgets only
        if all_distances_smaller:
            distance_smaller_mean = np.mean(all_distances_smaller)
            distance_smaller_std = np.std(all_distances_smaller)
            distance_smaller_max = np.max(all_distances_smaller)
        else:
            distance_smaller_mean = distance_smaller_std = distance_smaller_max = np.nan

        print(f"{avg_failures:<7.1f} {std_failures:<7.1f} {avg_failure_rate:<9.1f} {std_failure_rate:<9.1f} {avg_total_heavy:<8.1f} {avg_failed_heavy:<9.1f} "
              f"{recovery_mean:<7.1f} {recovery_med:<8.0f} {recovery_max:<8.0f} "
              f"{distance_any_mean:<8.1f} {distance_any_std:<10.1f} {distance_any_max:<10.0f} "
              f"{distance_smaller_mean:<10.1f} {distance_smaller_std:<12.1f} {distance_smaller_max:<12.0f}")

        return {
            'avg_failures': avg_failures,
            'failure_rate_pct': avg_failure_rate,
            'avg_recovery': recovery_mean,
            'n_groups': n_groups
        }


def _default_grouping_methods():
    """
    Build the (label, factory) pairs for every grouping strategy analysed.

    Each factory takes ``(allocator, df)`` and delegates to the allocator's
    matching ``create_*_groups`` method.
    """
    return [
        ('School', lambda allocator, df: allocator.create_school_groups(df, min_size=6)),
        ('Demographics', lambda allocator, df: allocator.create_demographics_groups(
            df, feature_cols=['gkfreelunch', 'race', 'gender'], min_size=6)),
        ('ML Prediction 30', lambda allocator, df: allocator.create_ml_prediction_groups(df, n_groups=30, min_size=6)),
        ('ML Prediction 50', lambda allocator, df: allocator.create_ml_prediction_groups(df, n_groups=50, min_size=6)),
        ('Propensity Score', lambda allocator, df: allocator.create_propensity_groups(df, n_groups=50, min_size=6)),
        ('Performance', lambda allocator, df: allocator.create_performance_groups(df, n_groups=50, min_size=6)),
    ]


def _summary_header():
    """Return the column header shared by the per-method and overall tables."""
    return (f"{'ε':<8} {'√ε':<10} {'γ':<6} {'ρ':<10} {'Groups':<8} "
            f"{'Fail μ':<8} {'FailR%':<8} {'Rec μ':<8}")


def _format_summary_row(row):
    """Format one summary row (without the method column) for table output."""
    return (f"{row['epsilon']:<8} {row['sqrt_eps']:<10.6f} {row['gamma']:<6} {row['rho']:<10.6f} "
            f"{row['n_groups']:<8} {row['avg_failures']:<8.1f} {row['failure_rate_pct']:<8.1f} "
            f"{row['avg_recovery']:<8.1f}")


def _evaluate_configuration(df_star, method_name, method_func, eps, n_trials, gamma, heavy_multiplier):
    """
    Run one (grouping method, epsilon) configuration.

    Returns the result dict for the configuration, or None when the method
    yields fewer than three groups or any step after data processing raises.
    """
    print(f"\n{'='*100}")
    print(f"METHOD: {method_name} | EPSILON = {eps}")
    print("="*100)

    # A fresh allocator per configuration; data processing is deliberately
    # outside the try block, so a failure here aborts the whole run.
    allocator = CATEAllocationAlgorithm(epsilon=eps, gamma=gamma, heavy_multiplier=heavy_multiplier)
    df_processed = allocator.process_star_data(df_star)

    try:
        groups = method_func(allocator, df_processed)

        if len(groups) < 3:
            print(f"Insufficient groups ({len(groups)}) for {method_name} with ε = {eps} - skipping")
            return None

        groups = allocator.normalize_cates(groups)
        allocator.plot_cate_distribution(groups, f" ({method_name}, ε={eps})")

        trial_data = allocator.analyze_method(groups, eps, n_trials)
        stats = allocator.print_method_summary(method_name, trial_data, len(groups))

        return {
            'method': method_name,
            'epsilon': eps,
            'sqrt_epsilon': np.sqrt(eps),
            'gamma': gamma,
            # Recorded for reporting as gamma * sqrt(epsilon)
            'rho': gamma * np.sqrt(eps),
            'groups': groups,
            'trial_data': trial_data,
            'stats': stats
        }
    except Exception as e:
        # Best-effort sweep: report the failing configuration and move on.
        print(f"Error with {method_name} at ε = {eps}: {e}")
        return None


def _summary_rows(method_name, method_results):
    """Flatten one method's per-epsilon results into summary-table rows."""
    return [
        {
            'method': method_name,
            'epsilon': result['epsilon'],
            'sqrt_eps': result['sqrt_epsilon'],
            'gamma': result['gamma'],
            'rho': result['rho'],
            'avg_failures': result['stats']['avg_failures'],
            'failure_rate_pct': result['stats']['failure_rate_pct'],
            'avg_recovery': result['stats']['avg_recovery'],
            'n_groups': result['stats']['n_groups']
        }
        for result in method_results
    ]


def _print_method_ranking(rows_by_method):
    """Print best/worst method and the ranking by mean failure rate."""
    method_performance = {
        name: np.mean([row['failure_rate_pct'] for row in rows])
        for name, rows in rows_by_method.items()
    }
    if not method_performance:
        return

    best_method = min(method_performance, key=method_performance.get)
    worst_method = max(method_performance, key=method_performance.get)

    print(f"Best performing method: {best_method}")
    print(f"  Average failure rate: {method_performance[best_method]:.1f}%")

    print(f"\nWorst performing method: {worst_method}")
    print(f"  Average failure rate: {method_performance[worst_method]:.1f}%")

    print("\nMethod ranking (by average failure rate):")
    ranked = sorted(method_performance.items(), key=lambda item: item[1])
    for position, (method, rate) in enumerate(ranked, 1):
        print(f"  {position}. {method}: {rate:.1f}%")


def _print_epsilon_effect(summary_data, epsilon_values, gamma):
    """Print the mean failure rate across methods for each tested epsilon."""
    print("\nEffect of epsilon parameter:")
    epsilon_performance = {}
    for eps in epsilon_values:
        rates = [row['failure_rate_pct'] for row in summary_data if row['epsilon'] == eps]
        if rates:
            epsilon_performance[eps] = np.mean(rates)

    if not epsilon_performance:
        return

    rho_label = f"ρ = {gamma}√ε"
    print(f"{'Epsilon':<10} {'Avg Failure Rate':<15} {rho_label:<12}")
    print("-" * 40)
    for eps in sorted(epsilon_performance):
        rho = gamma * np.sqrt(eps)
        print(f"{eps:<10} {epsilon_performance[eps]:<15.1f} {rho:<12.6f}")


def run_comprehensive_analysis(df_star, epsilon_values=None, n_trials=30, log_file=None,
                               *, gamma=0.5, heavy_multiplier=1.6):
    """
    Execute comprehensive analysis across all grouping methods.

    For every grouping method and every epsilon value, a fresh
    CATEAllocationAlgorithm is built, groups are created and normalized,
    the CATE distribution is plotted, and the trial analysis is run. All
    console output produced during the run is mirrored to ``log_file``;
    ``sys.stdout`` is restored and the log closed even if the run fails.

    Parameters:
    -----------
    df_star : pandas.DataFrame
        STAR dataset
    epsilon_values : list, optional
        Epsilon values to test. Defaults to
        [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2].
    n_trials : int, default=30
        Number of trials per configuration
    log_file : str, optional
        Output log file path. When omitted, a timestamped name of the form
        ``comprehensive_analysis_gamma<γ digits>_<timestamp>.txt`` is used.
    gamma : float, default=0.5 (keyword-only)
        Gamma passed to the allocator and used for the reported
        ρ = γ·√ε values.
    heavy_multiplier : float, default=1.6 (keyword-only)
        Heavy interval threshold multiplier passed to the allocator.

    Returns:
    --------
    tuple
        (all_results, summary_data): ``all_results`` maps each method name
        to its list of per-epsilon result dicts (skipped or failed
        configurations are omitted); ``summary_data`` is the flat list of
        summary-table rows.
    """

    if epsilon_values is None:
        epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Configure logging
    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        gamma_tag = str(gamma).replace('.', '')  # 0.5 -> "05"
        log_file = f"comprehensive_analysis_gamma{gamma_tag}_{timestamp}.txt"

    # Redirect output to both console and file
    original_stdout = sys.stdout
    logger = OutputLogger(log_file)
    sys.stdout = logger

    try:
        print("COMPREHENSIVE ALGORITHMIC ANALYSIS")
        print(f"Configuration: γ={gamma}, Heavy threshold={heavy_multiplier}x")
        print(f"Log file: {log_file}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*100)

        methods = _default_grouping_methods()

        # Sweep every method over every epsilon
        all_results = {}
        for method_name, method_func in methods:
            print(f"\n{'='*120}")
            print(f"ANALYZING METHOD: {method_name}")
            print("="*120)

            method_results = []
            for eps in epsilon_values:
                outcome = _evaluate_configuration(
                    df_star, method_name, method_func, eps, n_trials, gamma, heavy_multiplier)
                if outcome is not None:
                    method_results.append(outcome)

            all_results[method_name] = method_results

        # Generate comprehensive summary
        print(f"\n{'='*200}")
        print("COMPREHENSIVE SUMMARY - ALL METHODS AND EPSILON VALUES")
        print("="*200)

        summary_data = []
        rows_by_method = {}  # only methods that produced at least one result

        for method_name, method_results in all_results.items():
            if not method_results:
                continue

            print(f"\n{'-'*100}")
            print(f"METHOD: {method_name}")
            print("-"*100)

            method_rows = _summary_rows(method_name, method_results)
            summary_data.extend(method_rows)
            rows_by_method[method_name] = method_rows

            # Print method-specific table
            print(_summary_header())
            print("-" * 80)
            for row in method_rows:
                print(_format_summary_row(row))

        # Overall summary table
        print(f"\n{'='*200}")
        print("OVERALL SUMMARY TABLE - ALL METHODS COMBINED")
        print("="*200)
        print(f"{'Method':<18} {_summary_header()}")
        print("-" * 120)

        for row in summary_data:
            print(f"{row['method']:<18} {_format_summary_row(row)}")

        # Generate key insights
        print(f"\n{'='*100}")
        print("KEY INSIGHTS")
        print("="*100)

        _print_method_ranking(rows_by_method)
        _print_epsilon_effect(summary_data, epsilon_values, gamma)

        print(f"\nAnalysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Results saved to: {log_file}")
        print("\nConfiguration summary:")
        print(f"- Fixed γ = {gamma}")
        print(f"- Heavy interval threshold = {heavy_multiplier}× uniform expectation")
        print(f"- Trials per configuration = {n_trials}")
        print(f"- Methods tested = {len(methods)}")
        print(f"- Epsilon values tested = {len(epsilon_values)}")

        return all_results, summary_data

    finally:
        # Restore original stdout and close log
        sys.stdout = original_stdout
        logger.close()


if __name__ == "__main__":
    # Epsilon grid swept by the analysis
    epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Read the STAR student records from the SPSS file (relative path)
    df_star = pd.read_spss('STAR_Students.sav')

    # Run every grouping method against every epsilon; output printed during
    # the run is mirrored to the log file named here
    results, summary = run_comprehensive_analysis(
        df_star,
        epsilon_values=epsilon_values,
        n_trials=30,
        log_file="algorithmic_analysis_gamma05_heavy16.txt",
    )

    # stdout has been restored by now, so this recap goes to the console only
    n_methods = len(results)
    n_epsilons = len(epsilon_values)
    n_configurations = sum(map(len, results.values()))

    print(f"Analyzed {n_methods} grouping methods across {n_epsilons} epsilon values.")
    print(f"Total configurations: {n_configurations}")